#' Format MR results for forest plot
#'
#' This function takes the results from [mr()] and is particularly useful
#' if the MR has been applied using multiple exposures and multiple outcomes.
#' It creates a new data frame with the following:
#' \itemize{
#' \item Variables: exposure, outcome, category, outcome sample size, effect, upper ci, lower ci, pval, nsnp
#' \item only one estimate for each exposure-outcome
#' \item exponentiated effects if required
#' }
#'
#' By default it uses the [available_outcomes()] function to retrieve the study level characteristics for the outcome trait,
#' including sample size and outcome category.
#' This assumes the MR analysis was performed using outcome GWAS(s) contained in OpenGWAS.
#'
#' If \code{ao_slc} is set to \code{FALSE} then the user must supply their own study level characteristics.
#' This is useful when the user has supplied their own outcome GWAS results (i.e. they are not in OpenGWAS).
#'
#' Exposures that have no estimate for a given outcome are added as rows with
#' missing effect, confidence interval, p-value and nsnp, so that every outcome
#' has one row per exposure.
#'
#' @param mr_res Results from [mr()].
#' @param exponentiate Convert effects to OR? The default is `FALSE`.
#' @param single_snp_method Which of the single SNP methods to use when only 1 SNP was used to estimate the causal effect? The default is `"Wald ratio"`.
#' @param multi_snp_method Which of the multi-SNP methods to use when more than 1 SNP was used to estimate the causal effect? The default is `"Inverse variance weighted"`.
#' @param ao_slc Logical; retrieve sample size and subcategory using [available_outcomes()]. If set to `FALSE` `mr_res` must contain the following additional columns: `subcategory` and `sample_size`.
#' @param priority Name of category to prioritise at the top of the forest plot. The default is `"Cardiometabolic"`.
#'
#' @export
#' @return data frame with columns `exposure`, `outcome`, `category`,
#'   `effect`, `up_ci`, `lo_ci`, `nsnp`, `pval`, `sample_size` and `index`,
#'   sorted by outcome, with rows of the `priority` category first and rows
#'   of the `"Other"` category last.
format_mr_results <- function(
  mr_res,
  exponentiate = FALSE,
  single_snp_method = "Wald ratio",
  multi_snp_method = "Inverse variance weighted",
  ao_slc = TRUE,
  priority = "Cardiometabolic"
) {
  stopifnot(length(priority) == 1)

  if (ao_slc) {
    # Study-level characteristics of the outcomes from OpenGWAS.
    ao <- available_outcomes()
    ao$subcategory[ao$subcategory == "Cardiovascular"] <- "Cardiometabolic"
    ao$subcategory[ao$trait == "Type 2 diabetes"] <- "Cardiometabolic"
    # Rename so the merge below does not clash with mr_res$nsnp.
    names(ao)[names(ao) == "nsnp"] <- "nsnp.array"
  } else {
    # data.frame() silently drops NULL columns, so fail loudly instead.
    missing_cols <- setdiff(c("subcategory", "sample_size"), names(mr_res))
    if (length(missing_cols) > 0) {
      stop(
        "When ao_slc = FALSE, mr_res must contain the column(s): ",
        paste(missing_cols, collapse = ", "),
        call. = FALSE
      )
    }
  }

  # Keep one estimate per exposure-outcome pair: the single-SNP method when
  # only one SNP was used, otherwise the multi-SNP method.
  dat <- subset(
    mr_res,
    (nsnp == 1 & method == single_snp_method) | (nsnp > 1 & method == multi_snp_method)
  )
  # Remember the row order of the filtered results; it is restored below.
  dat$index <- seq_len(nrow(dat))

  if (ao_slc) {
    dat <- merge(dat, ao, by.x = "id.outcome", by.y = "id")
  }
  dat <- dat[order(dat$b), ]

  # 95% confidence intervals on the scale of b.
  b <- as.numeric(dat$b)
  se <- as.numeric(dat$se)
  dat$up_ci <- b + 1.96 * se
  dat$lo_ci <- b - 1.96 * se

  if (exponentiate) {
    dat$b <- exp(b)
    dat$up_ci <- exp(dat$up_ci)
    dat$lo_ci <- exp(dat$lo_ci)
  }

  # Build a plain trait name. With OpenGWAS metadata the merged `trait`
  # column is used as-is. Otherwise the outcome label from mr() is used:
  # labels of the form "Trait || ..." (MR-Base / OpenGWAS style) are cut at
  # the first "||", whatever the number of fields. User-supplied labels
  # without "||" are kept unchanged.
  if (!ao_slc) {
    dat$trait <- as.character(dat$outcome)
    has_sep <- grepl("||", dat$trait, fixed = TRUE)
    if (any(has_sep)) {
      fields <- strsplit(dat$trait[has_sep], "||", fixed = TRUE)
      dat$trait[has_sep] <- trim(vapply(fields, `[`, character(1), 1L))
    }
  }

  dat <- data.frame(
    exposure = as.character(dat$exposure),
    outcome = as.character(dat$trait),
    category = as.character(dat$subcategory),
    effect = dat$b,
    up_ci = dat$up_ci,
    lo_ci = dat$lo_ci,
    nsnp = dat$nsnp,
    pval = dat$pval,
    sample_size = dat$sample_size,
    index = dat$index,
    stringsAsFactors = FALSE
  )

  # Give every outcome one row per exposure. Missing pairs get NA estimates
  # and inherit the outcome's category and sample size.
  exps <- unique(dat$exposure)
  dat <- plyr::ddply(dat, "outcome", function(x) {
    missed <- setdiff(exps, x$exposure)
    if (length(missed) > 0) {
      md <- data.frame(
        exposure = missed,
        outcome = unique(x$outcome),
        category = unique(x$category),
        sample_size = unique(x$sample_size),
        stringsAsFactors = FALSE
      )
      x <- plyr::rbind.fill(x, md)
    }
    x
  })

  # Group by outcome, keeping the original order within each outcome.
  # Filled-in rows have NA index and so come last.
  dat <- dat[order(dat$outcome, dat$index), ]

  # Put the priority category first and "Other" last. subset() drops rows
  # whose condition is NA, so rows with NA category land in the middle group.
  if (priority %in% dat$category) {
    dat <- rbind(
      subset(dat, category == priority),
      subset(dat, !category %in% c(priority, "Other")),
      subset(dat, category == "Other")
    )
  }

  dat
}

#' Simple attempt at correcting string case
#'
#' Lower-cases each string, then upper-cases the first letter of every
#' space-separated word (e.g. `"body MASS index"` becomes `"Body Mass Index"`).
#'
#' @param x Character or array of character
#'
#' @keywords internal
#' @return Character vector the same length as `x` (also for empty input).
#'   It is named by the input strings when `x` is an unnamed character vector.
#'   `NA` elements stay `NA`.
simple_cap <- function(x) {
  vapply(
    x,
    function(s) {
      # Without this guard NA would be pasted into the string "NANA".
      if (is.na(s)) {
        return(NA_character_)
      }
      words <- strsplit(tolower(s), " ", fixed = TRUE)[[1]]
      paste0(toupper(substring(words, 1, 1)), substring(words, 2), collapse = " ")
    },
    character(1)
  )
}

#' Remove leading and trailing whitespace
#'
#' Strips any run of `[[:space:]]` characters from the start and from the end
#' of each element. Whitespace inside the string is left untouched.
#'
#' @param x Character or array of character
#'
#' @export
#' @return Character or array of character, the same length as `x` and keeping its attributes.
trim <- function(x) {
  without_leading <- sub("^[[:space:]]+", "", x)
  sub("[[:space:]]+$", "", without_leading)
}


#' Create fixed width label
#'
#' Combines a right-aligned number (typically the outcome sample size) with a
#' left-aligned, padded name. The labels line up when printed in a
#' monospaced font.
#'
#' @param n1 number
#' @param nom name
#'
#' @keywords internal
#' @return factor of labels whose levels are in order of first appearance
create_label <- function(n1, nom) {
  # Print numbers as whole integers. formatC()'s default for doubles is "g"
  # with 4 significant digits, which would turn e.g. 339224 into "3.392e+05"
  # and break the fixed width.
  n1_chr <- if (is.numeric(n1)) {
    formatC(n1, format = "d")
  } else {
    as.character(n1)
  }
  n1_chr[is.na(n1_chr)] <- "NA"
  n1_c <- formatC(n1_chr, width = max(nchar(n1_chr), 0L))

  # na.rm keeps a missing name from turning the format into "%-NAs".
  len_nom <- max(nchar(nom), 0L, na.rm = TRUE)
  nomp <- sprintf(paste0("%-", len_nom, "s"), nom)

  out <- paste0(n1_c, "    ", nomp)
  factor(out, levels = unique(out))
}

#' A basic forest plot
#'
#' This function is used to create a basic forest plot.
#' It requires the output from [format_mr_results()].
#' Each outcome label is a facet row with an alternating grey/white
#' background, and each exposure is a point with its confidence interval.
#'
#' @param dat Output from [format_mr_results()].
#' @param section Which category in dat to plot. If `NULL` then prints everything.
#' @param colour_group Which exposure to plot. If `NULL` then prints everything grouping by colour.
#' @param colour_group_first When `colour_group` is given, draw the outcome names in this panel (i.e. it is the first column)? Ignored when `colour_group` is `NULL`, in which case the names are always drawn. The default is `TRUE`.
#' @param xlab x-axis label. Default=`NULL`.
#' @param bottom Show x-axis text, ticks and label? Default=`TRUE`.
#' @param trans Transformation of x axis.
#' @param xlim x-axis limits. Confidence intervals are truncated to these bounds.
#' @param threshold p-value threshold to use for colouring points by significance level. If `NULL` (default) then colour layer won't be applied.
#'
#' @return ggplot object
#' @keywords internal
forest_plot_basic <- function(
  dat,
  section = NULL,
  colour_group = NULL,
  colour_group_first = TRUE,
  xlab = NULL,
  bottom = TRUE,
  trans = "identity",
  xlim = NULL,
  threshold = NULL
) {
  if (bottom) {
    text_colour <- ggplot2::element_text(colour = "black")
    tick_colour <- ggplot2::element_line(colour = "black")
    xlabname <- xlab
  } else {
    text_colour <- ggplot2::element_blank()
    tick_colour <- ggplot2::element_blank()
    xlabname <- NULL
  }

  # OR or log(OR)? Symmetric CIs suggest effects on the log scale, so the
  # null line goes at 0; otherwise (exponentiated effects) it goes at 1.
  # all.equal() returns a character vector (not FALSE) when the inputs
  # differ, so isTRUE() is needed to get a single logical.
  ci_symmetric <- isTRUE(all.equal(dat$effect - dat$lo_ci, dat$up_ci - dat$effect))
  null_line <- if (ci_symmetric) 0 else 1

  # Truncate CIs to the requested range. NA bounds (rows filled in for
  # exposures without an estimate) must stay NA: with na.rm = TRUE they
  # would become bars spanning the whole of xlim.
  if (!is.null(xlim)) {
    stopifnot(length(xlim) == 2)
    stopifnot(xlim[1] < xlim[2])
    dat$lo_ci <- pmax(dat$lo_ci, xlim[1])
    dat$up_ci <- pmin(dat$up_ci, xlim[2])
  }

  # The x-range is computed before the section/colour_group subsetting below,
  # so every panel drawn from the same `dat` gets the same range. Half the
  # range is added on the left to make room for the outcome names.
  up <- max(dat$up_ci, na.rm = TRUE)
  lo_orig <- min(dat$lo_ci, na.rm = TRUE)
  lo <- lo_orig - (up - lo_orig) * 0.5

  if (!is.null(section)) {
    dat <- subset(dat, category == section)
  }

  if (!is.null(colour_group)) {
    dat <- subset(dat, exposure == colour_group)
    if (!is.null(threshold)) {
      point_plot <- ggplot2::geom_point(size = 2, ggplot2::aes(colour = pval < threshold))
    } else {
      point_plot <- ggplot2::geom_point(size = 2)
    }
  } else {
    if (!is.null(threshold)) {
      point_plot <- ggplot2::geom_point(
        ggplot2::aes(colour = exposure, shape = pval < threshold),
        size = 2
      )
    } else {
      point_plot <- ggplot2::geom_point(ggplot2::aes(colour = exposure), size = 2)
    }
  }

  # Outcome names are drawn at the left edge, except in later columns of a
  # multi-column layout, which also drop the extra left margin and hide the
  # title by drawing it in white.
  if (is.null(colour_group) || colour_group_first) {
    outcome_labels <- ggplot2::geom_text(
      ggplot2::aes(label = outcome),
      x = lo,
      y = mean(c(1, length(unique(dat$exposure)))),
      hjust = 0,
      vjust = 0.5,
      size = 2.5
    )
    title_colour <- "black"
  } else {
    outcome_labels <- NULL
    lo <- lo_orig
    title_colour <- "white"
  }
  main_title <- section

  if (!"lab" %in% names(dat)) {
    dat$lab <- create_label(dat$sample_size, dat$outcome)
  }

  # Alternate background fill ("a"/"b") between consecutive facet rows.
  l <- data.frame(lab = sort(unique(dat$lab)), col = "a", stringsAsFactors = FALSE)
  l$col[seq_len(nrow(l)) %% 2 == 0] <- "b"

  dat <- merge(dat, l, by = "lab", all.x = TRUE)
  dat <- dat[rev(seq_len(nrow(dat))), ]

  # geom_errorbarh() is used up to ggplot2 3.5.2. Later versions use
  # geom_errorbar() with orientation = "y" instead.
  if (utils::packageVersion("ggplot2") <= "3.5.2") {
    ci_bars <- ggplot2::geom_errorbarh(
      ggplot2::aes(xmin = lo_ci, xmax = up_ci),
      height = 0,
      size = 0.4,
      colour = "#aaaaaa"
    )
  } else {
    ci_bars <- ggplot2::geom_errorbar(
      ggplot2::aes(xmin = lo_ci, xmax = up_ci),
      width = 0,
      size = 0.4,
      colour = "#aaaaaa",
      orientation = "y"
    )
  }

  ggplot2::ggplot(dat, ggplot2::aes(x = effect, y = exposure)) +
    ggplot2::geom_rect(
      ggplot2::aes(fill = col),
      xmin = -Inf,
      xmax = Inf,
      ymin = -Inf,
      ymax = Inf
    ) +
    ggplot2::geom_vline(
      xintercept = seq(ceiling(lo_orig), ceiling(up), by = 0.5),
      colour = "white",
      size = 0.3
    ) +
    ggplot2::geom_vline(xintercept = null_line, colour = "#333333", size = 0.3) +
    ci_bars +
    ggplot2::geom_point(colour = "black", size = 2.2) +
    point_plot +
    ggplot2::facet_grid(lab ~ .) +
    ggplot2::scale_x_continuous(trans = trans, limits = c(lo, up)) +
    ggplot2::scale_colour_brewer(type = "qual") +
    ggplot2::scale_fill_manual(values = c("#eeeeee", "#ffffff"), guide = "none") +
    ggplot2::theme(
      axis.line = ggplot2::element_blank(),
      axis.text.y = ggplot2::element_blank(),
      axis.ticks.y = ggplot2::element_blank(),
      axis.text.x = text_colour,
      axis.ticks.x = tick_colour,
      strip.background = ggplot2::element_rect(fill = "white", colour = "white"),
      strip.text = ggplot2::element_text(family = "Courier New", face = "bold", size = 9),
      legend.position = "none",
      legend.direction = "vertical",
      panel.grid.minor.x = ggplot2::element_blank(),
      panel.grid.minor.y = ggplot2::element_blank(),
      panel.grid.major.y = ggplot2::element_blank(),
      plot.title = ggplot2::element_text(hjust = 0, size = 12, colour = title_colour),
      plot.margin = ggplot2::unit(c(2, 3, 2, 0), units = "points"),
      plot.background = ggplot2::element_rect(fill = "white"),
      panel.spacing = ggplot2::unit(0, "lines"),
      panel.background = ggplot2::element_rect(colour = "red", fill = "grey", size = 1),
      strip.text.y = ggplot2::element_blank()
    ) +
    ggplot2::labs(y = NULL, x = xlabname, colour = "", fill = NULL, title = main_title) +
    outcome_labels
}


#' Outcome-name column for a multi-column forest plot
#'
#' Builds a panel that shows only the outcome names. It has one facet row per
#' outcome label and the same alternating background as the forest panels.
#' [forest_plot()] uses it as the first column when `in_columns = TRUE`.
#'
#' @param dat Output from [format_mr_results()], optionally with a `lab` column.
#' @param section Category of `dat` to show. If `NULL`, all rows are shown and
#'   the title is hidden.
#' @param bottom If `TRUE`, x-axis text and ticks are drawn in white, i.e.
#'   invisibly (presumably so this panel keeps the same height as the
#'   neighbouring panels that show an axis). If `FALSE`, they are blank.
#'
#' @return ggplot object
#' @keywords internal
#' @noRd
forest_plot_names <- function(dat, section = NULL, bottom = TRUE) {
  if (bottom) {
    text_colour <- ggplot2::element_text(colour = "white")
    tick_colour <- ggplot2::element_line(colour = "white")
    xlabname <- ""
  } else {
    text_colour <- ggplot2::element_blank()
    tick_colour <- ggplot2::element_blank()
    xlabname <- NULL
  }

  # Fixed [0, 1] canvas; the names start at its left edge.
  lo <- 0
  up <- 1

  if (!is.null(section)) {
    dat <- subset(dat, category == section)
    section_colour <- "black"
  } else {
    section_colour <- "white"
  }
  main_title <- section

  outcome_labels <- ggplot2::geom_text(
    ggplot2::aes(label = outcome),
    x = lo,
    y = mean(c(1, length(unique(dat$exposure)))),
    hjust = 0,
    vjust = 0.5,
    size = 3.5
  )

  if (!"lab" %in% names(dat)) {
    dat$lab <- create_label(dat$sample_size, dat$outcome)
  }

  # Alternate background fill ("a"/"b") between consecutive facet rows.
  l <- data.frame(lab = sort(unique(dat$lab)), col = "a", stringsAsFactors = FALSE)
  l$col[seq_len(nrow(l)) %% 2 == 0] <- "b"

  dat <- merge(dat, l, by = "lab", all.x = TRUE)

  ggplot2::ggplot(dat, ggplot2::aes(x = effect, y = exposure)) +
    ggplot2::geom_rect(ggplot2::aes(fill = col), xmin = -Inf, xmax = Inf, ymin = -Inf, ymax = Inf) +
    ggplot2::facet_grid(lab ~ .) +
    ggplot2::scale_x_continuous(limits = c(lo, up)) +
    ggplot2::scale_colour_brewer(type = "qual") +
    ggplot2::scale_fill_manual(values = c("#eeeeee", "#ffffff"), guide = "none") +
    ggplot2::theme(
      axis.line = ggplot2::element_blank(),
      axis.text.y = ggplot2::element_blank(),
      axis.ticks.y = ggplot2::element_blank(),
      axis.text.x = text_colour,
      axis.ticks.x = tick_colour,
      strip.background = ggplot2::element_rect(fill = "white", colour = "white"),
      strip.text = ggplot2::element_text(family = "Courier New", face = "bold", size = 11),
      legend.position = "none",
      legend.direction = "vertical",
      panel.grid.minor.x = ggplot2::element_blank(),
      panel.grid.minor.y = ggplot2::element_blank(),
      panel.grid.major.y = ggplot2::element_blank(),
      plot.title = ggplot2::element_text(hjust = 0, size = 12, colour = section_colour),
      plot.margin = ggplot2::unit(c(2, 0, 2, 0), units = "points"),
      plot.background = ggplot2::element_rect(fill = "white"),
      panel.spacing = ggplot2::unit(0, "lines"),
      panel.background = ggplot2::element_rect(colour = "red", fill = "grey", size = 1),
      strip.text.y = ggplot2::element_blank()
    ) +
    ggplot2::labs(y = NULL, x = xlabname, colour = "", fill = NULL, title = main_title) +
    outcome_labels
}


#' Forest plot for multiple exposures and multiple outcomes
#'
#' Plots the results of an MR analysis of multiple exposures against multiple
#' outcomes. The results from [mr()] are first reduced to one estimate per
#' exposure-outcome pair with [format_mr_results()].
#'
#' @param mr_res Results from [mr()].
#' @param exponentiate Convert effects to OR? Default is `FALSE`.
#' @param single_snp_method Which of the single SNP methods to use when only 1 SNP was used to estimate the causal effect? The default is `"Wald ratio"`.
#' @param multi_snp_method Which of the multi-SNP methods to use when more than 1 SNP was used to estimate the causal effect? The default is `"Inverse variance weighted"`.
#' @param group_single_categories If at least two categories contain only one outcome each, pool them into an "Other" category, plotted last. The default is `TRUE`.
#' @param by_category Separate the results into sections by category? The default is `TRUE`.
#' @param in_columns Separate the exposures into different columns. The default is `FALSE`.
#' @param threshold p-value threshold to use for colouring points by significance level. If `NULL` (default) then colour layer won't be applied.
#' @param xlab x-axis label. If `in_columns=TRUE` then the exposure values are appended to the end of `xlab`. e.g. if `xlab="Effect of"` then x-labels will read `"Effect of exposure1"`, `"Effect of exposure2"` etc. Otherwise will be printed as is.
#' @param xlim limit x-axis range. Provide vector of length 2, with lower and upper bounds. The default is `NULL`.
#' @param trans Transformation to apply to x-axis. e.g. `"identity"`, `"log2"`, etc. The default is `"identity"`.
#' @param ao_slc retrieve sample size and subcategory from [available_outcomes()]. If set to `FALSE` then `mr_res` must contain the following additional columns: `sample_size` and `subcategory`. The default behaviour is to use [available_outcomes()] to retrieve sample size and subcategory.
#' @param priority Name of category to prioritise at the top of the forest plot. The default is `"Cardiometabolic"`.
#'
#' @export
#' @return A ggplot object. Multi-panel layouts are assembled with
#'   `cowplot::plot_grid()`.
forest_plot <- function(
  mr_res,
  exponentiate = FALSE,
  single_snp_method = "Wald ratio",
  multi_snp_method = "Inverse variance weighted",
  group_single_categories = TRUE,
  by_category = TRUE,
  in_columns = FALSE,
  threshold = NULL,
  xlab = "",
  xlim = NULL,
  trans = "identity",
  ao_slc = TRUE,
  priority = "Cardiometabolic"
) {
  dat <- format_mr_results(
    mr_res,
    exponentiate = exponentiate,
    single_snp_method = single_snp_method,
    multi_snp_method = multi_snp_method,
    ao_slc = ao_slc,
    priority = priority
  )

  if (group_single_categories) {
    # format_mr_results() returns the category in a character column called
    # `category` (there is no `subcategory` column). Count distinct outcomes
    # per category.
    outcome_rows <- dat[!duplicated(dat$outcome), , drop = FALSE]
    tab <- table(outcome_rows$category)
    othercats <- names(tab)[tab == 1]
    # Pool only when at least two categories are singletons; a single one
    # would merely be renamed.
    if (length(othercats) > 1) {
      dat$category[dat$category %in% othercats] <- "Other"
      # format_mr_results() puts "Other" last; keep that for the pooled rows.
      # order() is stable, so the relative order of all other rows is kept.
      dat <- dat[order(dat$category %in% "Other"), , drop = FALSE]
    }
  }

  dat$lab <- create_label(dat$sample_size, dat$outcome)

  legend <- cowplot::get_legend(
    ggplot2::ggplot(dat, ggplot2::aes(x = effect, y = outcome)) +
      ggplot2::geom_point(ggplot2::aes(colour = exposure)) +
      ggplot2::scale_colour_brewer(type = "qual") +
      ggplot2::labs(colour = "Exposure") +
      ggplot2::theme(text = ggplot2::element_text(size = 10))
  )

  if (!by_category) {
    if (!in_columns) {
      return(
        forest_plot_basic(
          dat,
          bottom = TRUE,
          xlab = xlab,
          trans = trans,
          xlim = xlim,
          threshold = threshold
        ) +
          ggplot2::theme(legend.position = "left")
      )
    }

    # One names column followed by one column per exposure.
    columns <- unique(dat$exposure)
    l <- c(
      list(forest_plot_names(dat, section = NULL, bottom = TRUE)),
      lapply(columns, function(expo) {
        forest_plot_basic(
          dat,
          section = NULL,
          bottom = TRUE,
          colour_group = expo,
          colour_group_first = FALSE,
          xlab = paste0(xlab, " ", expo),
          trans = trans,
          xlim = xlim,
          threshold = threshold
        )
      })
    )
    return(
      cowplot::plot_grid(
        gridExtra::arrangeGrob(
          grobs = l,
          ncol = length(columns) + 1,
          nrow = 1,
          widths = c(4, rep(5, length(columns)))
        )
      )
    )
  }

  sec <- unique(as.character(dat$category))
  # Panel heights: one unit per outcome in the section, plus one for the
  # title, plus one more on the last section for the x-axis.
  h <- vapply(
    sec,
    function(s) length(unique(subset(dat, category == s)$outcome)),
    numeric(1),
    USE.NAMES = FALSE
  )
  h <- h + 1
  h[length(sec)] <- h[length(sec)] + 1

  if (!in_columns) {
    l <- lapply(seq_along(sec), function(i) {
      forest_plot_basic(
        dat,
        sec[i],
        bottom = i == length(sec),
        xlab = xlab,
        trans = trans,
        xlim = xlim,
        threshold = threshold
      )
    })

    return(
      cowplot::plot_grid(
        gridExtra::arrangeGrob(
          legend,
          gridExtra::arrangeGrob(grobs = l, ncol = 1, nrow = length(h), heights = h),
          ncol = 2,
          nrow = 1,
          widths = c(1, 5)
        )
      )
    )
  }

  # Grid of sections (rows) by [names column, one column per exposure].
  columns <- unique(dat$exposure)
  l <- list()
  count <- 1
  for (i in seq_along(sec)) {
    l[[count]] <- forest_plot_names(
      dat,
      sec[i],
      bottom = i == length(sec)
    )
    count <- count + 1
    for (j in seq_along(columns)) {
      l[[count]] <- forest_plot_basic(
        dat,
        sec[i],
        bottom = i == length(sec),
        colour_group = columns[j],
        colour_group_first = FALSE,
        xlab = paste0(xlab, " ", columns[j]),
        trans = trans,
        xlim = xlim,
        threshold = threshold
      )
      count <- count + 1
    }
  }

  cowplot::plot_grid(
    gridExtra::arrangeGrob(
      grobs = l,
      ncol = length(columns) + 1,
      nrow = length(h),
      heights = h,
      widths = c(4, rep(5, length(columns)))
    )
  )
}
